{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "# Ship the XGBoost4J jars with the PySpark driver. This has to be in the\n",
    "# environment before the first SparkSession launches the JVM, so keep this\n",
    "# cell ahead of any Spark code.\n",
    "os.environ['PYSPARK_SUBMIT_ARGS'] = (\n",
    "    '--jars xgboost4j-spark-0.90.jar,xgboost4j-0.90.jar pyspark-shell'\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import findspark\n",
    "# Cluster manager passed to every SparkSession builder in this notebook.\n",
    "master_url = \"yarn\"\n",
    "\n",
    "# Let findspark locate the Spark installation so that pyspark is importable.\n",
    "findspark.init()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# pyspark + xgboost smoke test: start a session, register the sparkxgb wrapper,\n",
    "# build an estimator to prove the import works, then stop the session again.\n",
    "from pyspark.sql import SparkSession, types\n",
    "\n",
    "spark = (\n",
    "    SparkSession.builder\n",
    "    .appName(\"PySpark XGBOOST Titanic\")\n",
    "    .master(master_url)\n",
    "    .getOrCreate()\n",
    ")\n",
    "\n",
    "# The wrapper ships as a zip, so it is registered with the SparkContext\n",
    "# before the import below is attempted.\n",
    "spark.sparkContext.addPyFile(\"sparkxgb.zip\")\n",
    "from sparkxgb import XGBoostClassifier\n",
    "\n",
    "xgboost = XGBoostClassifier(\n",
    "    predictionCol=\"prediction\",\n",
    "    labelCol=\"Survival\",\n",
    "    featuresCol=\"features\",\n",
    ")\n",
    "\n",
    "spark.stop()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# create spark session\n",
    "from pyspark.sql import SparkSession\n",
    "# Build the session in a single expression so `spark` is only ever bound to\n",
    "# a SparkSession, never to the intermediate builder.\n",
    "spark = (\n",
    "    SparkSession.builder\n",
    "    .master(master_url)\n",
    "    .appName(\"Training\")\n",
    "    .getOrCreate()\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# read data\n",
    "import pyspark\n",
    "from pyspark.sql import types\n",
    "from pyspark.sql.types import StructType, StructField\n",
    "\n",
    "def get_schema():\n",
    "    \"\"\"Return the explicit Spark schema used to parse the training CSV.\n",
    "\n",
    "    The schema has 20 fields: two encoded user ids, four user features, four\n",
    "    friend features, four edge features and six label columns. Id and feature\n",
    "    fields are declared nullable; label fields are declared non-nullable.\n",
    "    Field order matters because the loading code drops the last two fields\n",
    "    by position.\n",
    "\n",
    "    Returns:\n",
    "        StructType: one StructField per column, in declaration order.\n",
    "    \"\"\"\n",
    "    string, integer, boolean = types.StringType, types.IntegerType, types.BooleanType\n",
    "    columns = [\n",
    "        # (column name, Spark type class, nullable)\n",
    "        # user names\n",
    "        (\"encoded_ghost_user_id\", string, True),\n",
    "        (\"encoded_ghost_friend_user_id\", string, True),\n",
    "        # user feature\n",
    "        (\"user_contact_book_size\", integer, True),\n",
    "        (\"user_snapchatter_in_contact_size\", integer, True),\n",
    "        (\"user_story_post_score_7d\", integer, True),\n",
    "        (\"user_snap_sent_score_7d\", integer, True),\n",
    "        # friend feature\n",
    "        (\"friend_contact_book_size\", integer, True),\n",
    "        (\"friend_snapchatter_in_contact_size\", integer, True),\n",
    "        (\"friend_story_post_score_7d\", integer, True),\n",
    "        (\"friend_snap_sent_score_7d\", integer, True),\n",
    "        # edge feature\n",
    "        (\"friend_counting_score\", integer, True),\n",
    "        (\"contact_book_similarity_score\", integer, True),\n",
    "        (\"contact_book_normalized_similarity_score\", integer, True),\n",
    "        (\"is_friend_in_users_contact_book\", boolean, True),\n",
    "        # labels\n",
    "        (\"friend_addition_result\", boolean, False),\n",
    "        (\"friend_reciprocation_result\", boolean, False),\n",
    "        (\"total_snap_sent\", integer, False),\n",
    "        (\"total_snap_sent_from_friend\", integer, False),\n",
    "        # NOTE(review): the two story counters are typed as strings, unlike the\n",
    "        # other totals - confirm this is intended.\n",
    "        (\"total_story_viewed\", string, False),\n",
    "        (\"total_story_viewed_from_friend\", string, False),\n",
    "    ]\n",
    "    return StructType([\n",
    "        StructField(name, dtype(), nullable=nullable)\n",
    "        for name, dtype, nullable in columns\n",
    "    ])\n",
    "\n",
    "# Parse the CSV with the explicit schema; the literal MISSING_VALUE is read as null.\n",
    "data = spark.read.csv(\n",
    "    \"gs://shuyi-test/reg_model/data.csv\",\n",
    "    schema=get_schema(),\n",
    "    header=True,\n",
    "    nullValue=\"MISSING_VALUE\",\n",
    ")\n",
    "\n",
    "# The two boolean outcome columns are cast to double so they can be used as labels.\n",
    "for label_column in (\"friend_addition_result\", \"friend_reciprocation_result\"):\n",
    "    data = data.withColumn(label_column, data[label_column].cast(types.DoubleType()))\n",
    "\n",
    "# Drop story engagement data (last 2 columns) and cache the frame for reuse.\n",
    "data = data.select(*data.columns[:-2]).cache()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2524366\n",
      "2524366\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "['encoded_ghost_user_id',\n",
       " 'encoded_ghost_friend_user_id',\n",
       " 'user_contact_book_size',\n",
       " 'user_snapchatter_in_contact_size',\n",
       " 'user_story_post_score_7d',\n",
       " 'user_snap_sent_score_7d',\n",
       " 'friend_contact_book_size',\n",
       " 'friend_snapchatter_in_contact_size',\n",
       " 'friend_story_post_score_7d',\n",
       " 'friend_snap_sent_score_7d',\n",
       " 'friend_counting_score',\n",
       " 'contact_book_similarity_score',\n",
       " 'contact_book_normalized_similarity_score',\n",
       " 'is_friend_in_users_contact_book',\n",
       " 'friend_addition_result',\n",
       " 'friend_reciprocation_result',\n",
       " 'total_snap_sent',\n",
       " 'total_snap_sent_from_friend']"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# This cell is reserved for manually checking data only.\n",
    "# It prints the row count before and after removing rows that contain nulls.\n",
    "# DataFrames are immutable: na.drop() returns a new frame, so its result must\n",
    "# be used directly. `data` itself is deliberately left unchanged for training.\n",
    "#data.describe().show()\n",
    "print(data.count())\n",
    "print(data.na.drop().count())\n",
    "data.columns"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "PipelineModel_c9c0771ff1c2"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Train a random forest model\n",
    "from pyspark.ml.feature import VectorAssembler, StringIndexer, OneHotEncoder, StandardScaler\n",
    "from pyspark.ml import Pipeline\n",
    "from pyspark.ml.classification import RandomForestClassifier\n",
    "\n",
    "# Pack the twelve user, friend and edge features into a single vector column.\n",
    "assembler = VectorAssembler(\n",
    "    outputCol=\"features\",\n",
    "    inputCols=[\n",
    "        # user features\n",
    "        \"user_contact_book_size\",\n",
    "        \"user_snapchatter_in_contact_size\",\n",
    "        \"user_story_post_score_7d\",\n",
    "        \"user_snap_sent_score_7d\",\n",
    "        # friend features\n",
    "        \"friend_contact_book_size\",\n",
    "        \"friend_snapchatter_in_contact_size\",\n",
    "        \"friend_story_post_score_7d\",\n",
    "        \"friend_snap_sent_score_7d\",\n",
    "        # edge features\n",
    "        \"friend_counting_score\",\n",
    "        \"contact_book_similarity_score\",\n",
    "        \"contact_book_normalized_similarity_score\",\n",
    "        \"is_friend_in_users_contact_book\",\n",
    "    ],\n",
    ")\n",
    "\n",
    "rf = RandomForestClassifier(labelCol=\"friend_addition_result\", featuresCol=\"features\")\n",
    "\n",
    "# Fit assembler + forest as one pipeline; the fitted model is the cell's output.\n",
    "pipeline = Pipeline(stages=[assembler, rf])\n",
    "pipeline.fit(data)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Train a xgboost model\n",
    "from pyspark.ml.feature import VectorAssembler, StringIndexer, OneHotEncoder, StandardScaler\n",
    "from pyspark.ml import Pipeline\n",
    "spark.sparkContext.addPyFile(\"sparkxgb.zip\") # read xgboost pyspark client lib\n",
    "from sparkxgb import XGBoostClassifier\n",
    "\n",
    "# Same twelve input features as the random forest cell.\n",
    "assembler = VectorAssembler(\n",
    "    outputCol=\"features\",\n",
    "    inputCols=[\n",
    "        # user features\n",
    "        \"user_contact_book_size\",\n",
    "        \"user_snapchatter_in_contact_size\",\n",
    "        \"user_story_post_score_7d\",\n",
    "        \"user_snap_sent_score_7d\",\n",
    "        # friend features\n",
    "        \"friend_contact_book_size\",\n",
    "        \"friend_snapchatter_in_contact_size\",\n",
    "        \"friend_story_post_score_7d\",\n",
    "        \"friend_snap_sent_score_7d\",\n",
    "        # edge features\n",
    "        \"friend_counting_score\",\n",
    "        \"contact_book_similarity_score\",\n",
    "        \"contact_book_normalized_similarity_score\",\n",
    "        \"is_friend_in_users_contact_book\",\n",
    "    ],\n",
    ")\n",
    "\n",
    "# NOTE(review): a \"reg:\" objective is passed to a classifier, and missing=0.0\n",
    "# presumably makes zero-valued features count as missing - confirm both are intended.\n",
    "xgboost = XGBoostClassifier(\n",
    "    labelCol=\"friend_addition_result\",\n",
    "    featuresCol=\"features\",\n",
    "    objective=\"reg:logistic\",\n",
    "    maxDepth=3,\n",
    "    missing=0.0,\n",
    ")\n",
    "\n",
    "# Pipeline variant, currently disabled:\n",
    "# pipeline = Pipeline(stages=[assembler, xgboost])\n",
    "# trained_model = pipeline.fit(data)\n",
    "\n",
    "# Assemble and fit as two explicit steps so that both the assembled frame (`td`)\n",
    "# and the directly fitted model stay available to the following cells.\n",
    "td = assembler.transform(data)\n",
    "trained_raw_model = xgboost.fit(td)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+----------------------+--------------------+--------------------+----------+\n",
      "|friend_addition_result|       rawPrediction|         probability|prediction|\n",
      "+----------------------+--------------------+--------------------+----------+\n",
      "|                   0.0|[0.36648640036582...|[0.59060969948768...|       0.0|\n",
      "|                   0.0|[0.44295924901962...|[0.60896393656730...|       0.0|\n",
      "|                   0.0|[0.49293401837348...|[0.62079739570617...|       0.0|\n",
      "|                   0.0|[0.36648640036582...|[0.59060969948768...|       0.0|\n",
      "|                   0.0|[0.36648640036582...|[0.59060969948768...|       0.0|\n",
      "|                   1.0|[0.44295924901962...|[0.60896393656730...|       0.0|\n",
      "|                   0.0|[0.36648640036582...|[0.59060969948768...|       0.0|\n",
      "|                   0.0|[0.36648640036582...|[0.59060969948768...|       0.0|\n",
      "|                   0.0|[0.36648640036582...|[0.59060969948768...|       0.0|\n",
      "|                   1.0|[0.36648640036582...|[0.59060969948768...|       0.0|\n",
      "|                   0.0|[0.36648640036582...|[0.59060969948768...|       0.0|\n",
      "|                   0.0|[0.36648640036582...|[0.59060969948768...|       0.0|\n",
      "|                   0.0|[0.44295924901962...|[0.60896393656730...|       0.0|\n",
      "|                   0.0|[0.36648640036582...|[0.59060969948768...|       0.0|\n",
      "|                   0.0|[0.36648640036582...|[0.59060969948768...|       0.0|\n",
      "|                   0.0|[0.26787239313125...|[0.56657052040100...|       0.0|\n",
      "|                   0.0|[0.36648640036582...|[0.59060969948768...|       0.0|\n",
      "|                   0.0|[0.36648640036582...|[0.59060969948768...|       0.0|\n",
      "|                   0.0|[0.44295924901962...|[0.60896393656730...|       0.0|\n",
      "|                   0.0|[0.49293401837348...|[0.62079739570617...|       0.0|\n",
      "+----------------------+--------------------+--------------------+----------+\n",
      "only showing top 20 rows\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Pipeline variant, currently disabled:\n",
    "# result = trained_model.transform(data)\n",
    "# result.select([\"friend_addition_result\", \"rawPrediction\", \"probability\", \"prediction\"]).show()\n",
    "\n",
    "# Score the assembled training frame with the directly fitted model and show\n",
    "# the label next to the model outputs for the first 20 rows.\n",
    "result = trained_raw_model.transform(td)\n",
    "result.select([\"friend_addition_result\", \"rawPrediction\", \"probability\", \"prediction\"]).show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Save the trained model to local disk through the underlying native booster.\n",
    "native_booster = trained_raw_model.nativeBooster\n",
    "native_booster.saveModel(\"outputmodel.xgboost\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
